/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * SMOTE.java
 * 
 * Copyright (C) 2008 Ryan Lichtenwalter 
 * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
 */

package weka.filters.supervised.instance;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.Vector;

import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.SupervisedFilter;

/**
 * <!-- globalinfo-start --> Resamples a dataset by applying the Synthetic
 * Minority Oversampling TEchnique (SMOTE). The original dataset must fit
 * entirely in memory. The amount of SMOTE and number of nearest neighbors may
 * be specified. For more information, see <br/>
 * <br/>
 * Nitesh V. Chawla et. al. (2002). Synthetic Minority Over-sampling Technique.
 * Journal of Artificial Intelligence Research. 16:321-357.
 * <p/>
 * <!-- globalinfo-end -->
 * 
 * <!-- technical-bibtex-start --> BibTeX:
 * 
 * <pre>
 * &#64;article{al.2002,
 *    author = {Nitesh V. Chawla et. al.},
 *    journal = {Journal of Artificial Intelligence Research},
 *    pages = {321-357},
 *    title = {Synthetic Minority Over-sampling Technique},
 *    volume = {16},
 *    year = {2002}
 * }
 * </pre>
 * <p/>
 * <!-- technical-bibtex-end -->
 * 
 * <!-- options-start --> Valid options are:
 * <p/>
 * 
 * <pre>
 * -S &lt;num&gt;
 *  Specifies the random number seed
 *  (default 1)
 * </pre>
 * 
 * <pre>
 * -P &lt;percentage&gt;
 *  Specifies percentage of SMOTE instances to create.
 *  (default 100.0)
 * </pre>
 * 
 * <pre>
 * -K &lt;nearest-neighbors&gt;
 *  Specifies the number of nearest neighbors to use.
 *  (default 5)
 * </pre>
 * 
 * <pre>
 * -C &lt;value-index&gt;
 *  Specifies the index of the nominal class value to SMOTE
 *  (default 0: auto-detect non-empty minority class))
 * </pre>
 * 
 * <!-- options-end -->
 * 
 * @author Ryan Lichtenwalter (rlichtenwalter@gmail.com)
 * @version $Revision: 9657 $
 */
public class SMOTE extends Filter implements SupervisedFilter, OptionHandler,
    TechnicalInformationHandler {

  /** for serialization. */
  static final long serialVersionUID = -1653880819059250364L;

  /** the number of neighbors to use. */
  protected int m_NearestNeighbors = 5;

  /** the random seed to use. */
  protected int m_RandomSeed = 1;

  /** the percentage of SMOTE instances to create. */
  protected double m_Percentage = 100.0;

  /** the index of the class value. */
  protected String m_ClassValueIndex = "0";

  /** whether to detect the minority class automatically. */
  protected boolean m_DetectMinorityClass = true;

  /**
   * Returns a string describing this classifier.
   * 
   * @return a description of the classifier suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String globalInfo() {
    return "Resamples a dataset by applying the Synthetic Minority Oversampling TEchnique (SMOTE)."
        + " The original dataset must fit entirely in memory."
        + " The amount of SMOTE and number of nearest neighbors may be specified."
        + " For more information, see \n\n"
        + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing detailed
   * information about the technical background of this class, e.g., paper
   * reference or book this class is based on.
   * 
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result = new TechnicalInformation(Type.ARTICLE);

    result.setValue(Field.AUTHOR, "Nitesh V. Chawla et. al.");
    result.setValue(Field.TITLE, "Synthetic Minority Over-sampling Technique");
    result.setValue(Field.JOURNAL,
        "Journal of Artificial Intelligence Research");
    result.setValue(Field.YEAR, "2002");
    result.setValue(Field.VOLUME, "16");
    result.setValue(Field.PAGES, "321-357");

    return result;
  }

  /**
   * Returns the revision string.
   * 
   * @return the revision
   */
  
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 9657 $");
  }

  /**
   * Returns the Capabilities of this filter.
   * 
   * @return the capabilities of this object
   * @see Capabilities
   */
  
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enableAllAttributes();
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    return result;
  }

  /**
   * Returns an enumeration describing the available options.
   * 
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector newVector = new Vector();

    newVector.addElement(new Option("\tSpecifies the random number seed\n"
        + "\t(default 1)", "S", 1, "-S <num>"));

    newVector.addElement(new Option(
        "\tSpecifies percentage of SMOTE instances to create.\n"
            + "\t(default 100.0)\n", "P", 1, "-P <percentage>"));

    newVector.addElement(new Option(
        "\tSpecifies the number of nearest neighbors to use.\n"
            + "\t(default 5)\n", "K", 1, "-K <nearest-neighbors>"));

    newVector.addElement(new Option(
        "\tSpecifies the index of the nominal class value to SMOTE\n"
            + "\t(default 0: auto-detect non-empty minority class))\n", "C", 1,
        "-C <value-index>"));

    return newVector.elements();
  }

  /**
   * Parses a given list of options.
   * 
   * <!-- options-start --> Valid options are:
   * <p/>
   * 
   * <pre>
   * -S &lt;num&gt;
   *  Specifies the random number seed
   *  (default 1)
   * </pre>
   * 
   * <pre>
   * -P &lt;percentage&gt;
   *  Specifies percentage of SMOTE instances to create.
   *  (default 100.0)
   * </pre>
   * 
   * <pre>
   * -K &lt;nearest-neighbors&gt;
   *  Specifies the number of nearest neighbors to use.
   *  (default 5)
   * </pre>
   * 
   * <pre>
   * -C &lt;value-index&gt;
   *  Specifies the index of the nominal class value to SMOTE
   *  (default 0: auto-detect non-empty minority class))
   * </pre>
   * 
   * <!-- options-end -->
   * 
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    String seedStr = Utils.getOption('S', options);
    if (seedStr.length() != 0) {
      setRandomSeed(Integer.parseInt(seedStr));
    } else {
      setRandomSeed(1);
    }

    String percentageStr = Utils.getOption('P', options);
    if (percentageStr.length() != 0) {
      setPercentage(new Double(percentageStr).doubleValue());
    } else {
      setPercentage(100.0);
    }

    String nnStr = Utils.getOption('K', options);
    if (nnStr.length() != 0) {
      setNearestNeighbors(Integer.parseInt(nnStr));
    } else {
      setNearestNeighbors(5);
    }

    String classValueIndexStr = Utils.getOption('C', options);
    if (classValueIndexStr.length() != 0) {
      setClassValue(classValueIndexStr);
    } else {
      m_DetectMinorityClass = true;
    }
  }

  /**
   * Gets the current settings of the filter.
   * 
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    Vector<String> result;

    result = new Vector<String>();

    result.add("-C");
    result.add(getClassValue());

    result.add("-K");
    result.add("" + getNearestNeighbors());

    result.add("-P");
    result.add("" + getPercentage());

    result.add("-S");
    result.add("" + getRandomSeed());

    return result.toArray(new String[result.size()]);
  }

  /**
   * Returns the tip text for this property.
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String randomSeedTipText() {
    return "The seed used for random sampling.";
  }

  /**
   * Gets the random number seed.
   * 
   * @return the random number seed.
   */
  public int getRandomSeed() {
    return m_RandomSeed;
  }

  /**
   * Sets the random number seed.
   * 
   * @param value the new random number seed.
   */
  public void setRandomSeed(int value) {
    m_RandomSeed = value;
  }

  /**
   * Returns the tip text for this property.
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String percentageTipText() {
    return "The percentage of SMOTE instances to create.";
  }

  /**
   * Sets the percentage of SMOTE instances to create.
   * 
   * @param value the percentage to use
   */
  public void setPercentage(double value) {
    if (value >= 0)
      m_Percentage = value;
    else
      System.err.println("Percentage must be >= 0!");
  }

  /**
   * Gets the percentage of SMOTE instances to create.
   * 
   * @return the percentage of SMOTE instances to create
   */
  public double getPercentage() {
    return m_Percentage;
  }

  /**
   * Returns the tip text for this property.
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String nearestNeighborsTipText() {
    return "The number of nearest neighbors to use.";
  }

  /**
   * Sets the number of nearest neighbors to use.
   * 
   * @param value the number of nearest neighbors to use
   */
  public void setNearestNeighbors(int value) {
    if (value >= 1)
      m_NearestNeighbors = value;
    else
      System.err.println("At least 1 neighbor necessary!");
  }

  /**
   * Gets the number of nearest neighbors to use.
   * 
   * @return the number of nearest neighbors to use
   */
  public int getNearestNeighbors() {
    return m_NearestNeighbors;
  }

  /**
   * Returns the tip text for this property.
   * 
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String classValueTipText() {
    return "The index of the class value to which SMOTE should be applied. "
        + "Use a value of 0 to auto-detect the non-empty minority class.";
  }

  /**
   * Sets the index of the class value to which SMOTE should be applied.
   * 
   * @param value the class value index
   */
  public void setClassValue(String value) {
    m_ClassValueIndex = value;
    if (m_ClassValueIndex.equals("0")) {
      m_DetectMinorityClass = true;
    } else {
      m_DetectMinorityClass = false;
    }
  }

  /**
   * Gets the index of the class value to which SMOTE should be applied.
   * 
   * @return the index of the clas value to which SMOTE should be applied
   */
  public String getClassValue() {
    return m_ClassValueIndex;
  }

  /**
   * Sets the format of the input instances.
   * 
   * @param instanceInfo an Instances object containing the input instance
   *          structure (any instances contained in the object are ignored -
   *          only the structure is required).
   * @return true if the outputFormat may be collected immediately
   * @throws Exception if the input format can't be set successfully
   */
  
  public boolean setInputFormat(Instances instanceInfo) throws Exception {
    super.setInputFormat(instanceInfo);
    super.setOutputFormat(instanceInfo);
    return true;
  }

  /**
   * Input an instance for filtering. Filter requires all training instances be
   * read before producing output.
   * 
   * @param instance the input instance
   * @return true if the filtered instance may now be collected with output().
   * @throws IllegalStateException if no input structure has been defined
   */
  
  public boolean input(Instance instance) {
    if (getInputFormat() == null) {
      throw new IllegalStateException("No input instance format defined");
    }
    if (m_NewBatch) {
      resetQueue();
      m_NewBatch = false;
    }
    if (m_FirstBatchDone) {
      push(instance);
      return true;
    } else {
      bufferInput(instance);
      return false;
    }
  }

  /**
   * Signify that this batch of input to the filter is finished. If the filter
   * requires all instances prior to filtering, output() may now be called to
   * retrieve the filtered instances.
   * 
   * @return true if there are instances pending output
   * @throws IllegalStateException if no input structure has been defined
   * @throws Exception if provided options cannot be executed on input instances
   */
  
  public boolean batchFinished() throws Exception {
    if (getInputFormat() == null) {
      throw new IllegalStateException("No input instance format defined");
    }

    if (!m_FirstBatchDone) {
      // Do SMOTE, and clear the input instances.
      doSMOTE();
    }
    flushInput();

    m_NewBatch = true;
    m_FirstBatchDone = true;
    return (numPendingOutput() != 0);
  }

  /**
   * The procedure implementing the SMOTE algorithm. The output instances are
   * pushed onto the output queue for collection.
   * 
   * @throws Exception if provided options cannot be executed on input instances
   */
  protected void doSMOTE() throws Exception {
    int minIndex = 0;
    int min = Integer.MAX_VALUE;
    if (m_DetectMinorityClass) {
      // find minority class
      int[] classCounts = getInputFormat().attributeStats(
          getInputFormat().classIndex()).nominalCounts;
      for (int i = 0; i < classCounts.length; i++) {
        if (classCounts[i] != 0 && classCounts[i] < min) {
          min = classCounts[i];
          minIndex = i;
        }
      }
    } else {
      String classVal = getClassValue();
      if (classVal.equalsIgnoreCase("first")) {
        minIndex = 1;
      } else if (classVal.equalsIgnoreCase("last")) {
        minIndex = getInputFormat().numClasses();
      } else {
        minIndex = Integer.parseInt(classVal);
      }
      if (minIndex > getInputFormat().numClasses()) {
        throw new Exception("value index must be <= the number of classes");
      }
      minIndex--; // make it an index
    }

    int nearestNeighbors;
    if (min <= getNearestNeighbors()) {
      nearestNeighbors = min - 1;
    } else {
      nearestNeighbors = getNearestNeighbors();
    }
    if (nearestNeighbors < 1)
      throw new Exception("Cannot use 0 neighbors!");

    // compose minority class dataset
    // also push all dataset instances
    Instances sample = getInputFormat().stringFreeStructure();
    Enumeration instanceEnum = getInputFormat().enumerateInstances();
    while (instanceEnum.hasMoreElements()) {
      Instance instance = (Instance) instanceEnum.nextElement();
      push((Instance) instance.copy());
      if ((int) instance.classValue() == minIndex) {
        sample.add(instance);
      }
    }

    // compute Value Distance Metric matrices for nominal features
    Map vdmMap = new HashMap();
    Enumeration attrEnum = getInputFormat().enumerateAttributes();
    while (attrEnum.hasMoreElements()) {
      Attribute attr = (Attribute) attrEnum.nextElement();
      if (!attr.equals(getInputFormat().classAttribute())) {
        if (attr.isNominal() || attr.isString()) {
          double[][] vdm = new double[attr.numValues()][attr.numValues()];
          vdmMap.put(attr, vdm);
          int[] featureValueCounts = new int[attr.numValues()];
          int[][] featureValueCountsByClass = new int[getInputFormat()
              .classAttribute().numValues()][attr.numValues()];
          instanceEnum = getInputFormat().enumerateInstances();
          while (instanceEnum.hasMoreElements()) {
            Instance instance = (Instance) instanceEnum.nextElement();
            int value = (int) instance.value(attr);
            int classValue = (int) instance.classValue();
            featureValueCounts[value]++;
            featureValueCountsByClass[classValue][value]++;
          }
          for (int valueIndex1 = 0; valueIndex1 < attr.numValues(); valueIndex1++) {
            for (int valueIndex2 = 0; valueIndex2 < attr.numValues(); valueIndex2++) {
              double sum = 0;
              for (int classValueIndex = 0; classValueIndex < getInputFormat()
                  .numClasses(); classValueIndex++) {
                double c1i = featureValueCountsByClass[classValueIndex][valueIndex1];
                double c2i = featureValueCountsByClass[classValueIndex][valueIndex2];
                double c1 = featureValueCounts[valueIndex1];
                double c2 = featureValueCounts[valueIndex2];
                double term1 = c1i / c1;
                double term2 = c2i / c2;
                sum += Math.abs(term1 - term2);
              }
              vdm[valueIndex1][valueIndex2] = sum;
            }
          }
        }
      }
    }

    // use this random source for all required randomness
    Random rand = new Random(getRandomSeed());

    // find the set of extra indices to use if the percentage is not evenly
    // divisible by 100
    List extraIndices = new LinkedList();
    double percentageRemainder = (getPercentage() / 100)
        - Math.floor(getPercentage() / 100.0);
    int extraIndicesCount = (int) (percentageRemainder * sample.numInstances());
    if (extraIndicesCount >= 1) {
      for (int i = 0; i < sample.numInstances(); i++) {
        extraIndices.add(i);
      }
    }
    Collections.shuffle(extraIndices, rand);
    extraIndices = extraIndices.subList(0, extraIndicesCount);
    Set extraIndexSet = new HashSet(extraIndices);

    // the main loop to handle computing nearest neighbors and generating SMOTE
    // examples from each instance in the original minority class data
    Instance[] nnArray = new Instance[nearestNeighbors];
    for (int i = 0; i < sample.numInstances(); i++) {
      Instance instanceI = sample.instance(i);
      // find k nearest neighbors for each instance
      List distanceToInstance = new LinkedList();
      for (int j = 0; j < sample.numInstances(); j++) {
        Instance instanceJ = sample.instance(j);
        if (i != j) {
          double distance = 0;
          attrEnum = getInputFormat().enumerateAttributes();
          while (attrEnum.hasMoreElements()) {
            Attribute attr = (Attribute) attrEnum.nextElement();
            if (!attr.equals(getInputFormat().classAttribute())) {
              double iVal = instanceI.value(attr);
              double jVal = instanceJ.value(attr);
              if (attr.isNumeric()) {
                distance += Math.pow(iVal - jVal, 2);
              } else {
                distance += ((double[][]) vdmMap.get(attr))[(int) iVal][(int) jVal];
              }
            }
          }
          distance = Math.pow(distance, .5);
          distanceToInstance.add(new Object[] { distance, instanceJ });
        }
      }

      // sort the neighbors according to distance
      Collections.sort(distanceToInstance, new Comparator() {
        public int compare(Object o1, Object o2) {
          double distance1 = (Double) ((Object[]) o1)[0];
          double distance2 = (Double) ((Object[]) o2)[0];
          return Double.compare(distance1, distance2);
        }
      });

      // populate the actual nearest neighbor instance array
      Iterator entryIterator = distanceToInstance.iterator();
      int j = 0;
      while (entryIterator.hasNext() && j < nearestNeighbors) {
        nnArray[j] = (Instance) ((Object[]) entryIterator.next())[1];
        j++;
      }

      // create synthetic examples
      int n = (int) Math.floor(getPercentage() / 100);
      while (n > 0 || extraIndexSet.remove(i)) {
        double[] values = new double[sample.numAttributes()];
        int nn = rand.nextInt(nearestNeighbors);
        attrEnum = getInputFormat().enumerateAttributes();
        while (attrEnum.hasMoreElements()) {
          Attribute attr = (Attribute) attrEnum.nextElement();
          if (!attr.equals(getInputFormat().classAttribute())) {
            if (attr.isNumeric()) {
              double dif = nnArray[nn].value(attr) - instanceI.value(attr);
              double gap = rand.nextDouble();
              values[attr.index()] = (instanceI.value(attr) + gap * dif);
            } else if (attr.isDate()) {
              double dif = nnArray[nn].value(attr) - instanceI.value(attr);
              double gap = rand.nextDouble();
              values[attr.index()] = (long) (instanceI.value(attr) + gap * dif);
            } else {
              int[] valueCounts = new int[attr.numValues()];
              int iVal = (int) instanceI.value(attr);
              valueCounts[iVal]++;
              for (int nnEx = 0; nnEx < nearestNeighbors; nnEx++) {
                int val = (int) nnArray[nnEx].value(attr);
                valueCounts[val]++;
              }
              int maxIndex = 0;
              int max = Integer.MIN_VALUE;
              for (int index = 0; index < attr.numValues(); index++) {
                if (valueCounts[index] > max) {
                  max = valueCounts[index];
                  maxIndex = index;
                }
              }
              values[attr.index()] = maxIndex;
            }
          }
        }
        values[sample.classIndex()] = minIndex;
        Instance synthetic = new Instance(1.0, values);
        push(synthetic);
        n--;
      }
    }
  }

  /**
   * Main method for running this filter.
   * 
   * @param args should contain arguments to the filter: use -h for help
   */
  public static void main(String[] args) {
    runFilter(new SMOTE(), args);
  }
}
